import os
import torch
import yaml
import glob
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from PIL import Image
import numpy as np
def set_seed(seed):
    """Seed every RNG used in this script (NumPy, PyTorch CPU and CUDA)."""
    np.random.seed(seed)
    torch.manual_seed(seed)
    # CUDA seeding is a no-op on CPU-only machines; deterministic cuDNN
    # kernels make GPU runs repeatable as well.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
set_seed(0)
def extract_files():
    """Mount Google Drive and unpack the fiveK dataset zip into the CWD.

    Colab-only helper: the imports live inside the function so the module
    still loads on machines without `google.colab`.
    """
    import zipfile
    import google.colab
    google.colab.drive.mount('/content/drive')
    project_dir = "/content/drive/MyDrive/thesis/data/"
    # Context manager guarantees the archive is closed even if extraction fails.
    with zipfile.ZipFile(project_dir + "fiveK.zip", 'r') as zip_ref:
        zip_ref.extractall(".")
# Detect whether we are running inside Colab. get_ipython() only exists under
# IPython, so guard against NameError in plain-Python runs (the original
# crashed outside a notebook). The stray notebook-output line "cpu" that
# followed the print was removed — it was not valid Python.
try:
    _in_colab = 'google.colab' in str(get_ipython())
except NameError:
    _in_colab = False

if _in_colab:
    extract_files()
    config_path = "/content/drive/MyDrive/thesis/config.yaml"
else:
    config_path = "../../config.yaml"

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)
# Load the experiment configuration. Catch only file-system errors — the
# original bare `except:` also swallowed YAML parse errors and re-raised them
# as a misleading FileNotFoundError. Chaining (`from err`) keeps the real
# cause visible in the traceback.
try:
    with open(config_path, 'r') as config_file:
        config = yaml.safe_load(config_file)
except OSError as err:
    raise FileNotFoundError(f"Config file not found at path: {config_path}") from err

loss_type = config['unetmodel']['loss']
depth = config['unetmodel']['depth']
# Checkpoint directory is keyed by loss type and network depth.
base_checkpoint_path = f"{config['paths']['unetcheckpoints']}_{loss_type}_{depth}"
def load_best_checkpoint(checkpoint_dir, map_location=None):
    """Scan `checkpoint_dir` for 'unet_<epoch>.pth' files and pick the best one.

    Each checkpoint file is expected to be a dict with at least 'epoch',
    'train_loss' and 'val_loss' keys. "Best" means lowest validation loss.

    Args:
        checkpoint_dir: Directory containing the checkpoint files.
        map_location: Optional device spec forwarded to torch.load; defaults
            to the module-level `device`.

    Returns:
        (best_checkpoint, epochs, train_losses, val_losses). On a missing
        directory or no matching files, best_checkpoint is None and the
        lists are empty (the original returned a bare None, which crashed
        the 4-way unpacking at the call site).
    """
    if map_location is None:
        map_location = device  # module-level global chosen at import time

    # BUG FIX: the original tested the global `base_checkpoint_path` here
    # instead of the `checkpoint_dir` argument.
    if not os.path.exists(checkpoint_dir):
        print(f"No directory found: {checkpoint_dir}")
        return None, [], [], []

    checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'unet_*.pth'))
    if not checkpoint_files:
        print(f"No checkpoints found in the directory: {checkpoint_dir}")
        return None, [], [], []

    # Sort by the epoch number embedded in the filename: unet_<epoch>.pth.
    checkpoint_files.sort(key=lambda p: int(p.split('_')[-1].split('.')[0]))

    best_checkpoint = None
    best_val_loss = float('inf')
    epochs, train_losses, val_losses = [], [], []
    for checkpoint_file in checkpoint_files:
        checkpoint = torch.load(checkpoint_file, map_location=map_location)
        epochs.append(checkpoint['epoch'])
        train_losses.append(checkpoint['train_loss'])
        val_losses.append(checkpoint['val_loss'])
        # BUG FIX: the original compared train_loss against best_val_loss,
        # so it actually returned the best-*training*-loss checkpoint.
        if checkpoint['val_loss'] < best_val_loss:
            best_val_loss = checkpoint['val_loss']
            best_checkpoint = checkpoint
    return best_checkpoint, epochs, train_losses, val_losses
# Load the best checkpoint plus the full per-epoch loss history.
checkpoint, epochs, train_losses, val_losses = load_best_checkpoint(base_checkpoint_path)
# Plot training vs. validation loss across epochs.
import matplotlib.pyplot as plt
plt.plot(epochs, train_losses, label='Training Loss')
plt.plot(epochs, val_losses, label='Validation Loss')
plt.legend()
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Loss')
plt.show()
class ConvBlock(torch.nn.Module):
    """Two 3x3 conv + InstanceNorm + ReLU stages with optional 2x resampling.

    `downscale` prepends a MaxPool2d(2); `upscale` appends an Upsample(x2).
    The convolutions themselves preserve spatial size (padding=1).
    """

    def __init__(self, inchannels, outchannels, downscale=False, upscale=False):
        super(ConvBlock, self).__init__()
        # NOTE: attribute names are load-bearing — saved state_dicts key
        # off them — so they must not be renamed.
        self.down = nn.MaxPool2d(2) if downscale else nn.Identity()
        self.conv1 = nn.Conv2d(inchannels, outchannels, 3, padding=1)
        self.bnorm1 = nn.InstanceNorm2d(outchannels)
        self.conv2 = nn.Conv2d(outchannels, outchannels, 3, padding=1)
        self.bnorm2 = nn.InstanceNorm2d(outchannels)
        self.up = nn.Upsample(scale_factor=2) if upscale else nn.Identity()

    def forward(self, x):
        out = self.down(x)
        out = self.bnorm1(self.conv1(out)).relu()
        out = self.bnorm2(self.conv2(out)).relu()
        return self.up(out)
class UNet(torch.nn.Module):
    """U-Net conditioned on an integer class label via embedding layers.

    The label embedding is added to the feature map at the bottleneck and
    again before every decoder stage, so a single network can produce a
    different output per label.
    """

    def __init__(self, classes, depth):
        super(UNet, self).__init__()
        self.encoder = torch.nn.ModuleList()
        # Channel plan: 3 -> 64 -> 128 -> ... (doubling at each level).
        channels = [3] + [64 * (2 ** i) for i in range(depth)]
        for i in range(depth):
            # First encoder stage keeps full resolution; later ones halve it.
            self.encoder.append(ConvBlock(channels[i], channels[i + 1], downscale=(i > 0)))
        # One embedding vector per class, broadcast over H and W at the
        # bottleneck input.
        self.embedding_after_encoder = torch.nn.Embedding(classes, channels[-1])
        self.bottleneck = ConvBlock(channels[-1], channels[-1], downscale=True, upscale=True)
        self.decoder = torch.nn.ModuleList()
        self.decoder_embedding = torch.nn.ModuleList()
        # Reuse the channel list for decoder widths: the final decoder stage
        # emits 64 channels (reduced to 3 by `self.output`), not 3 directly.
        channels[0] = 64
        for i in range(depth - 1, -1, -1):
            # Decoder input is the upsampled features concatenated with the
            # matching skip connection, hence the factor of 2.
            self.decoder.append(ConvBlock(2 * channels[i + 1], channels[i], upscale=(i > 0)))
            # Embedding width matches x's channel count *before* the concat.
            self.decoder_embedding.append(torch.nn.Embedding(classes, 2 * channels[i] if i > 0 else channels[i]))
        self.output = torch.nn.Sequential(
            torch.nn.Conv2d(channels[0], 3, 1),
            torch.nn.Sigmoid()
        )

    def forward(self, x, label):
        # Encoder pass, stashing each stage's output for the skip connections.
        skip = []
        for mod in self.encoder:
            x = mod(x)
            skip.append(x)
        # Condition on the label: (B, C) -> (B, C, 1, 1), broadcast over H, W.
        embedded_label = self.embedding_after_encoder(label).unsqueeze(-1).unsqueeze(-1)
        x = x + embedded_label
        x = self.bottleneck(x)
        for mod, emb in zip(self.decoder, self.decoder_embedding):
            # Skips come back in reverse (deepest first).
            y = skip.pop()
            # add embedding with the decoder
            x = x + emb(label).unsqueeze(-1).unsqueeze(-1)
            x = torch.cat([x, y], 1)
            x = mod(x)
        # 1x1 conv + sigmoid maps to a 3-channel image in [0, 1].
        x = self.output(x)
        return x
# Rebuild the network at the configured depth and restore the best weights.
depth = config['unetmodel']['depth']
# BUG FIX: depth was hard-coded to 3 even though it is read from the config
# on the line above (and the checkpoint directory is keyed by that depth).
# Stray notebook-output lines ("<All keys matched successfully>", "39") were
# removed — they were not valid Python.
net = UNet(classes=2, depth=depth)
net = net.to(device)
net.load_state_dict(checkpoint['state_dict'])
print(checkpoint['epoch'])
classes = ['expB', 'expC']
test_images = ["0120", "0173", "0612", "0001"]
# Per-channel normalisation statistics used during training.
MEAN = [0.2279, 0.2017, 0.1825]
STD = [0.1191, 0.1092, 0.1088]

# Build the transform once instead of re-creating it on every iteration.
test_tr = transforms.Compose([
    transforms.ToTensor(),
    transforms.CenterCrop(224),
    transforms.Normalize(mean=MEAN, std=STD)
])

raw = torch.Tensor().to(device)
retouched_img = torch.Tensor().to(device)
labels = torch.Tensor().type(torch.int).to(device)
for test_image in test_images:
    for cls in classes:
        raw_img_path = os.path.join(config['paths']['data'], 'raw', test_image + '.png')
        retouched_img_path = os.path.join(config['paths']['data'], cls, test_image + '.png')
        # BUG FIX: move each freshly loaded tensor to `device` before
        # concatenating — the accumulators already live there, and mixing
        # CPU and CUDA tensors in torch.cat fails on GPU machines.
        raw = torch.cat((raw, test_tr(Image.open(raw_img_path)).unsqueeze(0).to(device)))
        retouched_img = torch.cat((retouched_img, test_tr(Image.open(retouched_img_path)).unsqueeze(0).to(device)))
        labels = torch.cat((labels, torch.tensor([classes.index(cls)]).to(device)))

# Run the conditioned U-Net on the whole batch without tracking gradients.
with torch.no_grad():
    net.eval()
    outputs = net(raw, labels)
# Show raw input, ground-truth retouched image (RI) and generated image (GI)
# side by side, one row per (test image, class) pair.
rows = len(test_images) * len(classes)
fig, axes = plt.subplots(rows, 3, figsize=(20, 53))

# Hoist the un-normalisation constants out of the loop.
std_t = torch.tensor(STD)
mean_t = torch.tensor(MEAN)

row = 0
for r_img, label, re_img, gi in zip(raw, labels, retouched_img, outputs):
    # BUG FIX: .cpu() before any numpy conversion/arithmetic — these tensors
    # live on `device`, and the original only moved `gi` (so this loop
    # raised on CUDA machines for the raw and retouched images).
    unnormalized_raw_img = r_img.cpu().permute(1, 2, 0) * std_t + mean_t
    unnormalized_retouched_img = re_img.cpu().permute(1, 2, 0) * std_t + mean_t
    axes[row, 0].imshow(unnormalized_raw_img.numpy().clip(0, 1))
    axes[row, 0].set_title('Raw')
    axes[row, 1].imshow(unnormalized_retouched_img.numpy().clip(0, 1))
    axes[row, 1].set_title(f'RI for {classes[label]}')
    axes[row, 2].imshow(gi.permute(1, 2, 0).cpu().numpy().clip(0, 1))
    axes[row, 2].set_title(f'GI for {classes[label]}')
    row += 1

for ax in axes.flatten():
    ax.axis('off')
plt.tight_layout()
plt.show()